library(tidyverse)
data <- read.csv("india_subset.csv")

## drop the leading column (presumably a row index saved with the CSV --
## TODO confirm) and give the remaining five columns short names
data <- data[, -1]
names(data) <- c("id", "A", "Z", "Y1", "Y2")

## keep only the rows where both outcomes Y1 and Y2 are observed
data <- data[!(is.na(data$Y1) | is.na(data$Y2)), ]


### impute missing A ----
# Each missing A is filled with the mean of the observed A values in the same
# cluster (`id`). NOTE(review): this presumes A is constant within a cluster,
# so the mean reproduces that common value -- confirm.
# A cluster with no observed A at all receives NaN (mean of an empty vector).
#
# Fix: the previous version assigned into every column of the matching rows
# (`data[rows, ] <- ...`, clobbering id, Z, Y1 and Y2) and averaged column 3,
# which is Z rather than A.
mis.A <- data$id[is.na(data$A)]
for (id in unique(mis.A)) {
  in.cluster <- data$id == id
  data$A[is.na(data$A) & in.cluster] <- mean(data$A[in.cluster], na.rm = TRUE)
}

### drop every unit with Z == 5 ----
data <- data[data$Z != 5, ]

### treatment indicator: Z codes 1, 2 and 3 (treatments A, B, C) become 1,
### every other code becomes 0 (NA stays NA)
data$Z <- as.numeric(data$Z == 1 | data$Z == 2 | data$Z == 3)

## recode A into three levels with a lookup on the original codes 1-5:
##   1, 2, 4 -> 1;   5 -> 2;   3 -> 3
## values outside 1-5 (including NA/NaN) are left untouched.
## NOTE(review): what levels 1-3 mean substantively (e.g. "high"/"low") is
## not established in this script -- confirm against the study codebook.
new.A.code <- c(1, 1, 3, 1, 2)
recode.rows <- data$A %in% 1:5
data$A[recode.rows] <- new.A.code[data$A[recode.rows]]

## number of distinct clusters remaining
length(unique(data$id))


### remove the clusters with <= 1 treated unit or <= 1 control unit ----
n1 <- tapply(data$Z, data$id, sum)       # treated units per cluster (Z is 0/1)
n0 <- tapply(1 - data$Z, data$id, sum)   # control units per cluster

# tapply() labels its result with the cluster ids converted to character, so
# the comparison is done on as.character(id). Converting the labels back with
# as.numeric() would give NA for non-numeric ids, and then no cluster would
# ever be removed.
small.clusters <- names(which(n1 <= 1 | n0 <= 1))
data <- filter(data, !(as.character(id) %in% small.clusters))

# A.cluster <- tapply(data$A,data$id,mean)

## counts of units at each level of A
table(data$A)

## contrast matrices ----
# Columns are paired as (2a - 1, 2a) for a = 1..m.
# NOTE(review): presumably CalAPO_df()'s Y.hat is ordered
# (a=1,z=1), (a=1,z=0), (a=2,z=1), ... -- confirm against its definition.
m <- 3
qa <- c(0.65, 0.15, 0.2)

odd.cols <- seq(1, 2 * m, by = 2)   # column 2a - 1 for each a
even.cols <- odd.cols + 1           # column 2a for each a

# C1: row a is (column 2a - 1) minus (column 2a)
C1 <- array(0, dim = c(m, 2 * m))
C1[cbind(seq_len(m), odd.cols)] <- 1
C1[cbind(seq_len(m), even.cols)] <- -1

# C2: the rows of C1 weighted by qa and summed into one vector
C2 <- rep(0, 2 * m)
C2[odd.cols] <- qa
C2[even.cols] <- -qa

# C3: adjacent-level differences; rows 1..m-1 use the odd columns,
# rows m..2m-2 use the even columns
C3 <- array(0, dim = c(2 * m - 2, 2 * m))
adj <- seq_len(m - 1)
C3[cbind(adj, odd.cols[adj])] <- 1
C3[cbind(adj, odd.cols[adj + 1])] <- -1
C3[cbind(m - 1 + adj, even.cols[adj])] <- 1
C3[cbind(m - 1 + adj, even.cols[adj + 1])] <- -1


### midline hospitalization ----
# Linear contrasts of est.midline$Y.hat for the midline outcome Y1. CalAPO_df()
# is defined elsewhere; only its Y.hat and cov.hat elements are used here.
# Every interval is estimate +/- 1.96 standard errors, where each standard
# error comes from est.midline$cov.hat.
est.midline <- CalAPO_df(data$id, data$Z, data$A, data$Y1)

## DE: one contrast per row of C1
de.midline <- C1 %*% est.midline$Y.hat
sd.de.midline <- sqrt(diag(C1 %*% est.midline$cov.hat %*% t(C1)))
lower.de.midline <- de.midline - 1.96 * sd.de.midline
upper.de.midline <- de.midline + 1.96 * sd.de.midline

# TRUE means rejected:
# Test2SRE(Z,A,Y.LTFC,effect="DE")

## MDE: the single contrast C2, reduced to a scalar with drop()
mde.midline <- drop(C2 %*% est.midline$Y.hat)
sd.mde.midline <- drop(sqrt(t(C2) %*% est.midline$cov.hat %*% (C2)))
lower.mde.midline <- mde.midline - 1.96 * sd.mde.midline
upper.mde.midline <- mde.midline + 1.96 * sd.mde.midline

# Test2SRE(Z,A,Y.LTFC,effect="MDE")

## SE: one contrast per row of C3
se.midline <- C3 %*% est.midline$Y.hat
sd.se.midline <- sqrt(diag(C3 %*% est.midline$cov.hat %*% t(C3)))
lower.se.midline <- se.midline - 1.96 * sd.se.midline
upper.se.midline <- se.midline + 1.96 * sd.se.midline

# Test2SRE(Z,A,Y.LTFC,effect="SE")


###############
### endline hospitalization ----
# Same contrasts as for the midline outcome, applied to the endline outcome Y2.
# CalAPO_df() is defined elsewhere; only its Y.hat and cov.hat elements are
# used. Every interval is estimate +/- 1.96 standard errors.
est.endline <- CalAPO_df(data$id, data$Z, data$A, data$Y2)

## DE: one contrast per row of C1
de.endline <- C1 %*% est.endline$Y.hat
sd.de.endline <- sqrt(diag(C1 %*% est.endline$cov.hat %*% t(C1)))
lower.de.endline <- de.endline - 1.96 * sd.de.endline
upper.de.endline <- de.endline + 1.96 * sd.de.endline

## heterogeneity: print lower and upper bounds for de.endline[1] - de.endline[3].
## c.1v3 equals C1[1, ] - C1[3, ], so its standard error is
## sqrt(t(c.1v3) %*% cov.hat %*% c.1v3).
c.1v3 <- c(1, -1, 0, 0, -1, 1)
sd.1v3 <- sqrt(t(c.1v3) %*% est.endline$cov.hat %*% (c.1v3))
de.endline[1] - de.endline[3] - 1.96 * sd.1v3
de.endline[1] - de.endline[3] + 1.96 * sd.1v3

# TRUE means rejected:
# Test2SRE(Z,A,Y.LTFC,effect="DE")

## MDE: the single contrast C2, reduced to a scalar with drop()
mde.endline <- drop(C2 %*% est.endline$Y.hat)
sd.mde.endline <- drop(sqrt(t(C2) %*% est.endline$cov.hat %*% (C2)))
lower.mde.endline <- mde.endline - 1.96 * sd.mde.endline
upper.mde.endline <- mde.endline + 1.96 * sd.mde.endline

# Test2SRE(Z,A,Y.LTFC,effect="MDE")

## SE: one contrast per row of C3
se.endline <- C3 %*% est.endline$Y.hat
sd.se.endline <- sqrt(diag(C3 %*% est.endline$cov.hat %*% t(C3)))
lower.se.endline <- se.endline - 1.96 * sd.se.endline
upper.se.endline <- se.endline + 1.96 * sd.se.endline



#### display the result ----
# Point estimates and 95% intervals for every contrast, one facet per outcome.
# The `y` column places DE rows 1-3 at 12, 11, 10, the MDE at 7 and the four
# SE rows at 4, 3, 2, 1; these positions must stay aligned with `breaks` and
# `labels` in scale_y_continuous() below.

display.midline <- tibble(
  est = c(de.midline, mde.midline, se.midline),
  lower = c(lower.de.midline, lower.mde.midline, lower.se.midline),
  upper = c(upper.de.midline, upper.mde.midline, upper.se.midline),
  y = c(12, 11, 10, 7, 4, 3, 2, 1),
  effect = "Midline hospitalization"
)

display.endline <- tibble(
  est = c(de.endline, mde.endline, se.endline),
  lower = c(lower.de.endline, lower.mde.endline, lower.se.endline),
  upper = c(upper.de.endline, upper.mde.endline, upper.se.endline),
  y = c(12, 11, 10, 7, 4, 3, 2, 1),
  effect = "Endline hospitalization"
)

display <- rbind(display.midline, display.endline)
# Ordered factor so the midline facet is drawn before the endline facet.
display$effect <- factor(
  display$effect,
  ordered = TRUE,
  levels = c("Midline hospitalization", "Endline hospitalization")
)

india.plot <- ggplot(data = display) +
  geom_point(aes(x = est, y = y)) +
  geom_errorbar(aes(y = y, xmin = lower, xmax = upper), width = 0.2) +
  geom_vline(xintercept = 0, color = "grey", linetype = 2) +
  scale_x_continuous(name = NULL, labels = scales::percent) +
  scale_y_continuous(
    name = NULL,
    labels = c("ASE(0;2,3)", "ASE(0;1,2)", "ASE(1;2,3)", "ASE(1;1,2)",
               "MDE", "ADE(3)", "ADE(2)", "ADE(1)"),
    breaks = c(1:4, 7, 10:12)
  ) +
  theme_bw() +
  facet_wrap(~effect)
india.plot

# The output path defaults to the author's local paper directory; set the
# INDIA_FIG_PATH environment variable to write the figure somewhere else.
fig.path <- Sys.getenv(
  "INDIA_FIG_PATH",
  unset = "/Users/Zhichao/Dropbox/github/kosuke/mismatch/interference2/paper/figs/india_analysis.pdf"
)
ggsave(fig.path, plot = india.plot, height = 4, width = 8)


### effect size ----
# Sample-size calculation from the variance components estimated for each
# outcome. Calpara() and Calsamplesize2() are defined elsewhere. For each
# outcome this prints var.*$r and then the rounded Calsamplesize2() result.
mu <- 0.05
pa <- c(0.9, 0.7, 0.5)

## midline outcome (Y1)
var.midline <- Calpara(data$id, data$Z, data$A, data$Y1)
var.midline$r
sigma.midline <- var.midline$sigmab + var.midline$sigmaw
n.bar <- round(var.midline$n.bar)
round(Calsamplesize2(mu, n.bar, qa, pa, r = var.midline$r, sigma.midline,
                     alpha = 0.05, beta = 0.2))

## endline outcome (Y2)
var.endline <- Calpara(data$id, data$Z, data$A, data$Y2)
var.endline$r
sigma.endline <- var.endline$sigmab + var.endline$sigmaw
n.bar <- round(var.endline$n.bar)
round(Calsamplesize2(mu, n.bar, qa, pa, r = var.endline$r, sigma.endline,
                     alpha = 0.05, beta = 0.2))


